{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import coremltools as ct\n",
    "import math\n",
    "from repvit_sam.utils.transforms import ResizeLongestSide\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "def show_mask(mask, ax):\n",
    "    color = np.array([30/255, 144/255, 255/255, 0.6])\n",
    "    h, w = mask.shape[-2:]\n",
    "    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
    "    ax.imshow(mask_image)\n",
    "    \n",
    "def show_points(coords, labels, ax, marker_size=375):\n",
    "    pos_points = coords[labels==1]\n",
    "    neg_points = coords[labels==0]\n",
    "    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
    "    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   \n",
    "\n",
    "def preprocess(x, img_size=1024):\n",
    "    \"\"\"Resize, normalize and pad an image for the encoder.\n",
    "\n",
    "    The image is resized with ``ResizeLongestSide(img_size)``, permuted from\n",
    "    (H, W, C) to (C, H, W), normalized per channel with fixed mean/std\n",
    "    values, and padded on the bottom and right to ``img_size`` x ``img_size``.\n",
    "    Assumes ``x`` is an H x W x 3 image array - TODO confirm with callers.\n",
    "\n",
    "    Returns:\n",
    "        (tensor, transform): the padded tensor and the ``ResizeLongestSide``\n",
    "        instance that was used, so the same scaling can be applied to\n",
    "        prompt coordinates.\n",
    "    \"\"\"\n",
    "    transform = ResizeLongestSide(img_size)\n",
    "    resized = torch.as_tensor(transform.apply_image(x))\n",
    "    chw = resized.permute(2, 0, 1).contiguous()\n",
    "\n",
    "    # Per-channel statistics shaped (C, 1, 1) so they broadcast over H and W.\n",
    "    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)\n",
    "    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)\n",
    "    normalized = (chw - pixel_mean) / pixel_std\n",
    "\n",
    "    # Pad only on the right and bottom edges up to img_size.\n",
    "    height, width = normalized.shape[-2:]\n",
    "    pad_spec = (0, img_size - width, 0, img_size - height)\n",
    "    return F.pad(normalized, pad_spec), transform\n",
    "\n",
    "def postprocess(raw_image, masks):\n",
    "    def resize_longest_image_size(\n",
    "            input_image_size, longest_side: int\n",
    "        ):\n",
    "            scale = longest_side / max(input_image_size)\n",
    "            transformed_size = [int(math.floor(scale * each + 0.5)) for each in input_image_size]\n",
    "            return transformed_size\n",
    "\n",
    "    prepadded_size = resize_longest_image_size(raw_image.shape[:2], masks.shape[2])\n",
    "    masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore\n",
    "\n",
    "    h, w = raw_image.shape[:2]\n",
    "    masks = F.interpolate(torch.tensor(masks), size=(h, w), mode=\"bilinear\", align_corners=False)\n",
    "    masks = masks > 0\n",
    "    return masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the CoreML export scripts for the RepViT encoder (resolution 1024) and the decoder.\n",
    "# NOTE(review): later cells load packages from 'coreml/'; presumably these scripts write there - confirm.\n",
    "!python3 ../scripts/export_coreml_encoder.py --resolution 1024 --model repvit --samckpt ../weights/repvit_sam.pt\n",
    "!python3 ../scripts/export_coreml_decoder.py --checkpoint ../weights/repvit_sam.pt --model-type repvit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the image encoder CoreML package from the relative 'coreml/' directory.\n",
    "encoder = ct.models.MLModel('coreml/repvit_1024.mlpackage')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the mask decoder CoreML package from the relative 'coreml/' directory.\n",
    "decoder = ct.models.MLModel('coreml/sam_decoder.mlpackage')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_path = '../../app/assets/picture3.jpg'\n",
    "raw_image = cv2.imread(image_path)\n",
    "# cv2.imread returns None instead of raising when the file is missing or\n",
    "# unreadable; fail here with a clear message rather than inside cvtColor.\n",
    "if raw_image is None:\n",
    "    raise FileNotFoundError(f'Could not read image: {image_path}')\n",
    "# OpenCV loads images as BGR; convert to RGB before preprocessing.\n",
    "raw_image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)\n",
    "image, transform = preprocess(raw_image)\n",
    "# Add a leading batch axis, feed the encoder input 'x_1', and keep the first output value.\n",
    "image_embedding = list(encoder.predict({'x_1': image.numpy()[None, ...]}).values())[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One prompt point as (x, y) in raw-image pixels; label 1 marks it as a positive point.\n",
    "input_point = np.array([[553, 808]])\n",
    "input_label = np.array([1])\n",
    "\n",
    "# Add a leading batch axis and cast to float32 for the decoder.\n",
    "coreml_label = np.expand_dims(input_label, axis=0).astype(np.float32)\n",
    "coreml_coord = np.expand_dims(input_point, axis=0).astype(np.float32)\n",
    "\n",
    "# Rescale the point with the same transform that resized the image.\n",
    "coreml_coord = transform.apply_coords(coreml_coord, raw_image.shape[:2]).astype(np.float32)\n",
    "\n",
    "# All-zero mask prompt with has_mask_input set to 0.\n",
    "coreml_has_mask_input = np.zeros(1, dtype=np.float32)\n",
    "coreml_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs for decoder.predict; each keyword below becomes a string key of the dict.\n",
    "ort_inputs = dict(\n",
    "    image_embeddings=image_embedding,\n",
    "    point_coords=coreml_coord,\n",
    "    point_labels=coreml_label,\n",
    "    mask_input=coreml_mask_input,\n",
    "    has_mask_input=coreml_has_mask_input,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the unpacking relies on the order of the returned dict's values - confirm the decoder's output order.\n",
    "decoder_outputs = decoder.predict(ort_inputs)\n",
    "low_res_logits, score, masks = decoder_outputs.values()\n",
    "\n",
    "# Draw the raw image, then the mask overlay and the prompt point on the same axes.\n",
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(raw_image)\n",
    "full_res_mask = postprocess(raw_image, masks)\n",
    "ax = plt.gca()\n",
    "show_mask(full_res_mask, ax)\n",
    "show_points(input_point, input_label, ax)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
